﻿
using CNative.DbUtils;
using CNative.Utilities;
using System;
using System.Collections.Generic;
using System.Configuration;
using System.Data;
using System.Data.Common;
using System.Data.SqlClient;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace CNative.DbUtils
{
    /// <summary>
    /// Database access helper. Resolves a named connection entry from configuration, creates the
    /// matching <see cref="BaseProvider"/>, and executes commands, queries and transactions through it.
    /// When master/slave separation is enabled, read-only text commands issued outside a transaction
    /// are routed to a randomly chosen slave connection.
    /// </summary>
    /// <remarks>
    /// An instance keeps mutable state (the current connection string and the active transaction)
    /// and performs no synchronization of that state.
    /// </remarks>
    public class DbHelper : IDbHelper
    {
        #region Fields and properties

        /// <summary>Command timeout in seconds used when <see cref="CommandTimeout"/> is not set.</summary>
        private const int DefaultCommandTimeout = 45;

        /// <summary>Whole-word, case-insensitive, culture-invariant match for the SELECT keyword.</summary>
        private static readonly System.Text.RegularExpressions.Regex ReadKeywordRegex =
            new System.Text.RegularExpressions.Regex(
                @"\bselect\b",
                System.Text.RegularExpressions.RegexOptions.IgnoreCase | System.Text.RegularExpressions.RegexOptions.CultureInvariant);

        /// <summary>
        /// Whole-word, case-insensitive, culture-invariant match for keywords that modify data or schema.
        /// Includes INTO / CREATE so that "SELECT ... INTO" and "CREATE TABLE ... AS SELECT" are not
        /// mistaken for reads. Over-matching only sends a read to the master, which is safe.
        /// </summary>
        private static readonly System.Text.RegularExpressions.Regex WriteKeywordRegex =
            new System.Text.RegularExpressions.Regex(
                @"\b(insert|update|delete|merge|into|create|alter|drop|truncate)\b",
                System.Text.RegularExpressions.RegexOptions.IgnoreCase | System.Text.RegularExpressions.RegexOptions.CultureInvariant);

        /// <summary>Name of the connection entry to load from configuration.</summary>
        private readonly string _dbName;

        private string _ConnectString;
        /// <summary>
        /// Connection string used for the next new connection. Loaded from configuration on first
        /// access; after a read has been routed to a slave it holds that slave's connection string.
        /// </summary>
        public string ConnectString
        {
            get
            {
                if (string.IsNullOrEmpty(_ConnectString))
                    SetConnectString(_dbName);
                return _ConnectString;
            }
            set { _ConnectString = value; }
        }

        /// <summary>
        /// Whether master/slave separation is requested. Overwritten from configuration when the
        /// configured entry lists slave connections.
        /// </summary>
        public bool IsUseMasterSlaveSeparation { get; set; } = false;

        private string _MasterConnectString;
        /// <summary>
        /// Master connection string. Loaded from configuration on first access.
        /// </summary>
        public string MasterConnectString
        {
            get
            {
                if (string.IsNullOrEmpty(_MasterConnectString))
                    SetConnectString(_dbName);
                return _MasterConnectString;
            }
            set { _MasterConnectString = value; }
        }

        /// <summary>Slave connection strings that answered a ping during initialization.</summary>
        public List<string> SlaveConnectStrings { get; set; }

        /// <summary>
        /// True when separation is requested and at least one slave connection string is available.
        /// </summary>
        protected virtual bool IsMasterSlaveSeparation =>
            IsUseMasterSlaveSeparation && SlaveConnectStrings != null && SlaveConnectStrings.Any();

        /// <summary>Command timeout in seconds; when null, 45 seconds is used.</summary>
        public int? CommandTimeout { get; set; }

        /// <summary>Database type resolved from the configured provider name.</summary>
        protected DatabaseType _dbType = DatabaseType.SqlServer;

        /// <summary>Database type; triggers configuration loading when not yet loaded.</summary>
        public DatabaseType DBType
        {
            get
            {
                EnsureConfigured();
                return _dbType;
            }
        }

        /// <summary>Provider created for the configured database type.</summary>
        private BaseProvider _SqlDbProvider;

        /// <summary>Database provider; triggers configuration loading when not yet loaded.</summary>
        public BaseProvider SqlDbProvider
        {
            get
            {
                EnsureConfigured();
                return _SqlDbProvider;
            }
        }

        /// <summary>ADO.NET factory of the current provider, used to create connections, commands and adapters.</summary>
        protected DbProviderFactory _dbProviderFactory = null;

        /// <summary>Active transaction, or null when commands run on their own connections.</summary>
        protected IDbTransaction _transaction = null;

        /// <summary>Maps a database type to the <see cref="BaseProvider"/> subtype that implements it.</summary>
        protected static Dictionary<DatabaseType, Type> DicDbProvider = new Dictionary<DatabaseType, Type>();

        /// <summary>Maps provider names (case-insensitive) from configuration to a database type.</summary>
        protected static Dictionary<string, DatabaseType> DicDBType = new Dictionary<string, DatabaseType>(StringComparer.OrdinalIgnoreCase);

        #endregion

        #region Constructors

        /// <summary>
        /// Creates a helper for the connection entry named <paramref name="dbName"/>. Configuration is
        /// loaded lazily on first use.
        /// </summary>
        /// <param name="dbName">Name of the connection entry in configuration.</param>
        public DbHelper(string dbName)
        {
            _dbName = dbName;
        }

        /// <summary>
        /// Registers the built-in provider-name aliases and provider implementations.
        /// </summary>
        static DbHelper()
        {
            // Every enum name is accepted both bare ("MySql") and with a "system.data." prefix.
            foreach (var typeName in Enum.GetNames(typeof(DatabaseType)))
            {
                if (!Enum.TryParse(typeName, out DatabaseType databaseType))
                    continue;

                DicDBType[typeName] = databaseType;
                DicDBType["system.data." + typeName] = databaseType;
            }

            DicDBType["System.Data.SqlClient"] = DatabaseType.SqlServer;
            DicDBType["System.Data.OracleClient"] = DatabaseType.Oracle;               // relies on an Oracle client being installed separately
            DicDBType["Oracle.DataAccess.Client"] = DatabaseType.Oracle;               // Oracle's official unmanaged driver (more restrictions)
            DicDBType["Oracle.ManagedDataAccess.Client"] = DatabaseType.Oracle;        // Oracle's official managed driver; no dependencies, not for versions below 10g

            DicDbProvider[DatabaseType.SqlServer] = typeof(SqlServerProvider);
            DicDbProvider[DatabaseType.MySql] = typeof(MySqlProvider);
            DicDbProvider[DatabaseType.Oracle] = typeof(OracleProvider);
            DicDbProvider[DatabaseType.Sqlite] = typeof(SqliteProvider);
            DicDbProvider[DatabaseType.MsAccess] = typeof(MsAccessProvider);
#if !NET40
            DicDbProvider[DatabaseType.PostgreSql] = typeof(PostgreSqlProvider);
#endif
        }

        #endregion

        #region Transactions

        /// <summary>
        /// Starts a transaction on a master connection. Subsequent commands on this instance run inside
        /// it until <see cref="Commit"/> or <see cref="Rollback"/> is called.
        /// </summary>
        /// <param name="isol">Locking behavior of the transaction.</param>
        /// <returns>The started transaction (also stored in <see cref="_transaction"/>).</returns>
        public IDbTransaction BeginTransaction(System.Data.IsolationLevel isol = IsolationLevel.ReadCommitted)
        {
            // Without an active transaction GetDbConnection opens a fresh connection that we own and
            // must release if BeginTransaction fails; otherwise it returns the existing transaction's connection.
            var ownsConnection = _transaction == null;
            var connection = GetDbConnection(null);
            try
            {
                _transaction = connection.BeginTransaction(isol);
            }
            catch
            {
                if (ownsConnection)
                    connection.Dispose();
                throw;
            }
            return _transaction;
        }

        /// <summary>
        /// Commits the active transaction (if any), then disposes it and closes its connection.
        /// If the commit itself throws, the transaction is left in place so that <see cref="Rollback"/> can run.
        /// </summary>
        public void Commit()
        {
            if (_transaction == null)
                return;

            // Capture the connection first: some providers (e.g. SqlTransaction) return null from
            // Connection once the transaction has completed, which would leak the connection.
            var connection = _transaction.Connection;
            _transaction.Commit();
            ReleaseTransaction(connection);
        }

        /// <summary>
        /// Rolls back the active transaction (if any), then disposes it and closes its connection.
        /// Never throws: it is called from error paths and must not mask the original exception.
        /// The transaction is always cleared, even when the rollback itself fails.
        /// </summary>
        public void Rollback()
        {
            var transaction = _transaction;
            if (transaction == null)
                return;

            IDbConnection connection = null;
            try
            {
                // Capture before rolling back (see Commit).
                connection = transaction.Connection;
                transaction.Rollback();
            }
            catch
            {
                // Deliberately swallowed, as before; the connection is still released below.
            }
            finally
            {
                ReleaseTransaction(connection);
            }
        }

        /// <summary>
        /// Disposes the active transaction, clears <see cref="_transaction"/>, and closes/disposes the
        /// given connection. Best effort: cleanup failures are swallowed because the transaction
        /// outcome has already been decided.
        /// </summary>
        private void ReleaseTransaction(IDbConnection connection)
        {
            try
            {
                _transaction?.Dispose();
            }
            catch
            {
                // best-effort cleanup
            }
            _transaction = null;

            if (connection == null)
                return;
            try
            {
                connection.Close();
                connection.Dispose();
            }
            catch
            {
                // best-effort cleanup
            }
        }

        #endregion

        #region Connections and commands

        /// <summary>
        /// Creates and opens a new connection using <see cref="ConnectString"/>.
        /// The connection is disposed if it cannot be opened.
        /// </summary>
        protected virtual IDbConnection GetDbConnection()
        {
            var connection = _dbProviderFactory.CreateConnection();
            try
            {
                connection.ConnectionString = ConnectString;
                if (connection.State != ConnectionState.Open)
                    connection.Open();
                return connection;
            }
            catch
            {
                connection.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Returns the connection to run <paramref name="sql"/> on: the active transaction's connection
        /// when a transaction is open, otherwise a new connection (master, or a slave for reads).
        /// </summary>
        /// <param name="sql">Statement to run; null selects the master connection.</param>
        protected virtual IDbConnection GetDbConnection(SqlEntity sql)
        {
            if (_dbProviderFactory == null) SetConnectString(_dbName);

            if (_transaction != null)
                return _transaction.Connection;

            SetSlaveConnectString(sql);
            return GetDbConnection();
        }

        /// <summary>
        /// Creates a command for <paramref name="sql"/> with its connection, transaction, timeout and
        /// parameters set. If setup fails, the command and any connection it opened are released.
        /// </summary>
        protected virtual IDbCommand CreateCommand(SqlEntity sql)
        {
            if (_dbProviderFactory == null) SetConnectString(_dbName);

            IDbCommand cmd = _dbProviderFactory.CreateCommand();
            try
            {
                if (_transaction == null)
                {
                    cmd.Connection = GetDbConnection(sql);
                }
                else
                {
                    // Enlist in the active transaction by reusing its connection.
                    cmd.Transaction = _transaction;
                    cmd.Connection = _transaction.Connection;
                }
                cmd.CommandType = sql.CommandType;
                cmd.CommandText = sql.Sql;
                cmd.CommandTimeout = CommandTimeout ?? DefaultCommandTimeout;

                BindParameters(cmd as DbCommand, sql);
                return cmd;
            }
            catch
            {
                CloseCommand(cmd);
                throw;
            }
        }

        /// <summary>
        /// Disposes the command and, when no transaction is active, closes and disposes its connection.
        /// Inside a transaction the connection belongs to the transaction and stays open.
        /// Best effort: cleanup failures are swallowed.
        /// </summary>
        protected virtual void CloseCommand(IDbCommand cmd)
        {
            if (cmd == null)
                return;

            try
            {
                var connection = cmd.Connection;
                if (_transaction == null && connection != null)
                {
                    if (connection.State == ConnectionState.Open)
                        connection.Close();
                    connection.Dispose();
                }
            }
            catch
            {
                // best-effort cleanup
            }

            try
            {
                cmd.Dispose();
            }
            catch
            {
                // best-effort cleanup
            }
        }

        #endregion

        #region Execute

        /// <summary>
        /// Executes a non-query statement. On failure the active transaction (if any) is rolled back
        /// and the original exception is rethrown with its stack trace intact.
        /// </summary>
        /// <returns>Always true when no exception is thrown.</returns>
        public virtual bool Execute(SqlEntity sql)
        {
            var cmd = CreateCommand(sql);
            try
            {
                cmd.ExecuteNonQuery();
                return true;
            }
            catch
            {
                Rollback();
                throw;
            }
            finally
            {
                CloseCommand(cmd);
            }
        }

        /// <summary>
        /// Executes several statements. When there is more than one statement and no transaction is
        /// already active, they run in a transaction owned by this call. An already-active (caller's)
        /// transaction is reused and left for the caller to commit.
        /// </summary>
        /// <returns>False when <paramref name="sqlList"/> is null; otherwise true.</returns>
        public virtual bool Execute(List<SqlEntity> sqlList)
        {
            if (sqlList == null) return false;

            var ownsTransaction = sqlList.Count > 1 && _transaction == null;
            if (ownsTransaction) BeginTransaction();

            // A failing statement rolls back the transaction inside Execute(SqlEntity) and rethrows.
            foreach (var sql in sqlList)
                Execute(sql);

            if (ownsTransaction) Commit();
            return true;
        }

        /// <summary>
        /// Executes a statement, ignoring its result. This call is synchronous and propagates errors
        /// exactly like <see cref="Execute(SqlEntity)"/>.
        /// </summary>
        public void ExecuteOneWay(SqlEntity sql)
        {
            Execute(sql);
        }

        #endregion

        #region Query

        /// <summary>
        /// Runs a query and fills a <see cref="DataSet"/>. On failure the active transaction (if any)
        /// is rolled back and the exception is rethrown.
        /// </summary>
        public virtual DataSet QueryDataSet(SqlEntity sql)
        {
            var cmd = CreateCommand(sql);
            try
            {
                using (var adapter = _dbProviderFactory.CreateDataAdapter())
                {
                    adapter.SelectCommand = cmd as DbCommand;
                    var dataSet = new DataSet();
                    adapter.Fill(dataSet);
                    return dataSet;
                }
            }
            catch
            {
                Rollback();
                throw;
            }
            finally
            {
                CloseCommand(cmd);
            }
        }

        /// <summary>
        /// Runs a query once and returns its first result table, or null when there is none.
        /// </summary>
        public DataTable QueryDataTable(SqlEntity sql)
        {
            var dataSet = QueryDataSet(sql);
            return dataSet != null && dataSet.Tables.Count > 0 ? dataSet.Tables[0] : null;
        }

        /// <summary>
        /// Runs a query and maps the first result table to entities.
        /// </summary>
        /// <returns>The mapped rows, or null when the query returns no rows.</returns>
        public List<T> Query<T>(SqlEntity sql) where T : class
        {
            var table = QueryDataTable(sql);
            return table != null && table.Rows.Count > 0 ? EntityHelper.DataTableToList<T>(table) : null;
        }

        /// <summary>
        /// Runs a SQL text query with parameters and maps the rows to entities
        /// (null when there are no rows).
        /// </summary>
        public List<T> Query<T>(string sql, List<IDataParameter> parameters) where T : class
        {
            var sqlEntity = CreateSqlEntity();
            sqlEntity.Sql = sql;
            sqlEntity.Parameters = parameters;
            return Query<T>(sqlEntity);
        }

        #endregion

        #region GetSingle

        /// <summary>
        /// Executes a scalar query and converts the result to <typeparamref name="T"/>.
        /// NULL / DBNull yields default(T) for value types (including Nullable&lt;T&gt;).
        /// On failure (including conversion failure) the active transaction is rolled back and the
        /// exception is rethrown.
        /// </summary>
        public virtual T GetSingle<T>(SqlEntity sql)
        {
            var cmd = CreateCommand(sql);
            try
            {
                return ConvertScalar<T>(cmd.ExecuteScalar());
            }
            catch
            {
                Rollback();
                throw;
            }
            finally
            {
                CloseCommand(cmd);
            }
        }

        /// <summary>
        /// Converts a scalar database value to <typeparamref name="T"/>.
        /// </summary>
        private static T ConvertScalar<T>(object value)
        {
            if (value == null || value is DBNull)
            {
                // Convert.ChangeType cannot turn null/DBNull into a value type; use its default instead.
                if (typeof(T).IsValueType)
                    return default(T);
                // Reference types keep the previous Convert.ChangeType behavior (e.g. DBNull -> "" for string).
                return (T)Convert.ChangeType(value, typeof(T));
            }

            if (value is T)
                return (T)value;

            // Convert.ChangeType does not understand Nullable<T>; convert to the underlying type instead.
            var targetType = Nullable.GetUnderlyingType(typeof(T)) ?? typeof(T);
            return (T)Convert.ChangeType(value, targetType);
        }

        /// <summary>
        /// Runs a query and returns the first mapped row, or null when there are no rows.
        /// </summary>
        public T GetSingleRow<T>(SqlEntity sql) where T : class
        {
            return Query<T>(sql)?.FirstOrDefault();
        }

        #endregion

        #region BindParameters

        /// <summary>
        /// Adds clones of the entity's parameters to the command (cloned via the Clone_ extension).
        /// </summary>
        protected virtual void BindParameters(DbCommand cmd, SqlEntity sql)
        {
            if (sql.Parameters == null || sql.Parameters.Count == 0) return;
            cmd.Parameters.AddRange(sql.Parameters.Clone_().ToArray());
        }

        #endregion

        #region CreateSqlEntity

        /// <summary>
        /// Creates a new <see cref="SqlEntity"/> bound to this helper.
        /// </summary>
        public SqlEntity CreateSqlEntity()
        {
            return new SqlEntity(this);
        }

        #endregion

        #region Configuration and providers

        /// <summary>
        /// Loads configuration when the provider or the current connection string is missing.
        /// </summary>
        private void EnsureConfigured()
        {
            if (_SqlDbProvider == null || string.IsNullOrEmpty(_ConnectString))
                SetConnectString(_dbName);
        }

        /// <summary>
        /// Loads the named connection entry from configuration. It sets the master/current connection
        /// string, timeout, database type, provider and factory, and the slave list.
        /// </summary>
        /// <param name="DBname">Name of the connection entry in configuration.</param>
        /// <exception cref="InvalidOperationException">The configuration or the named entry is missing.</exception>
        protected void SetConnectString(string DBname)
        {
            var connectionStrings = CNative.Utilities.ConfigurationHelper.ConnectionStrings;
            if (connectionStrings == null || connectionStrings.Count == 0)
                throw new InvalidOperationException("appsettings.json文件中未找到[ConnectionStrings]配置节点");
            if (!connectionStrings.ContainsKey(DBname))
                throw new InvalidOperationException("未找到名称为[" + DBname + "]的连接字符串");

            var configSettings = connectionStrings[DBname];
            if (configSettings == null)
                throw new InvalidOperationException("未找到名称为" + DBname + "的数据库配置节点");

            _MasterConnectString = configSettings.ConnectionString;
            _ConnectString = _MasterConnectString;
            CommandTimeout = configSettings.CommandTimeout;

            var provider = GetDbProvider(configSettings.ProviderName);
            _dbType = provider.Item1;
            _SqlDbProvider = provider.Item2;
            _dbProviderFactory = _SqlDbProvider.DbProviderFactory;

            InitSlaveConnectStrings(configSettings);
        }

        /// <summary>
        /// Resolves a provider name from configuration. Known aliases (see the static constructor and
        /// <see cref="AddSqlDbTypeMapping"/>) map to a registered provider. Any other name is passed to
        /// FastReflection.FastInstance as a provider type name.
        /// </summary>
        /// <exception cref="InvalidOperationException">No provider could be created for the name.</exception>
        protected virtual Tuple<DatabaseType, BaseProvider> GetDbProvider(string providerName)
        {
            // The alias dictionary is case-insensitive; ToLowerInvariant only keeps keys culture-independent.
            if (DicDBType.TryGetValue(providerName.Trim().ToLowerInvariant(), out DatabaseType dbType))
                return Tuple.Create(dbType, GetSqlDbProviderFactory(dbType));

            var provider = FastReflection.FastInstance<BaseProvider>(providerName, this);
            if (provider == null)
                throw new InvalidOperationException("未找到名称为" + providerName + "的数据库提供者");

            return Tuple.Create(provider.DBType, provider);
        }

        /// <summary>
        /// Instantiates the provider registered for <paramref name="dbType"/>.
        /// </summary>
        /// <param name="dbType">Database type.</param>
        /// <exception cref="NotSupportedException">No provider is registered for the type.</exception>
        protected virtual BaseProvider GetSqlDbProviderFactory(DatabaseType dbType)
        {
            if (!DicDbProvider.TryGetValue(dbType, out Type providerType))
                throw new NotSupportedException($"未实现{dbType.ToString()}数据库提供者");

            return providerType.FastInstance<BaseProvider>(this);
        }

        #endregion

        #region Provider registration

        /// <summary>
        /// Registers (or replaces) the provider implementation for a database type.
        /// </summary>
        /// <param name="dbType">Database type, e.g. SqlServer.</param>
        /// <param name="providerType">A <see cref="BaseProvider"/> subtype, e.g. typeof(SqlServerProvider).</param>
        /// <returns>False when <paramref name="providerType"/> is null or not a <see cref="BaseProvider"/>; otherwise true.</returns>
        public static bool AddSqlDbProvider(DatabaseType dbType, Type providerType)
        {
            // Reject types that GetSqlDbProviderFactory could not instantiate as a BaseProvider.
            if (providerType == null || !typeof(BaseProvider).IsAssignableFrom(providerType)) return false;

            if (DicDbProvider == null) DicDbProvider = new Dictionary<DatabaseType, Type>();
            DicDbProvider[dbType] = providerType;
            return true;
        }

        /// <summary>
        /// Registers (or replaces) a provider-name alias for a database type.
        /// </summary>
        /// <param name="dbTypeStr">Provider name as written in configuration, e.g. System.Data.SqlClient.</param>
        /// <param name="dbType">Database type, e.g. SqlServer.</param>
        /// <returns>False when <paramref name="dbTypeStr"/> is null or empty; otherwise true.</returns>
        public static bool AddSqlDbTypeMapping(string dbTypeStr, DatabaseType dbType)
        {
            if (dbTypeStr.IsNullOrEmpty()) return false;
            if (DicDBType == null) DicDBType = new Dictionary<string, DatabaseType>(StringComparer.OrdinalIgnoreCase);
            DicDBType[dbTypeStr] = dbType;
            return true;
        }

        /// <summary>
        /// Registers a provider-name alias together with the provider implementation for the type.
        /// </summary>
        /// <param name="dbTypeStr">Provider name, e.g. System.Data.SqlClient.</param>
        /// <param name="dbType">Database type, e.g. SqlServer.</param>
        /// <param name="providerType">Provider implementation, e.g. typeof(SqlServerProvider).</param>
        /// <returns>The result of <see cref="AddSqlDbProvider"/>.</returns>
        public static bool AddSqlDbProviderMapping(string dbTypeStr, DatabaseType dbType, Type providerType)
        {
            AddSqlDbTypeMapping(dbTypeStr, dbType);
            return AddSqlDbProvider(dbType, providerType);
        }

        #endregion

        #region Ping

        /// <summary>
        /// Checks whether a connection string is usable by opening a connection and delegating to the
        /// DbConnectionExtensions.Ping helper. The connection is always disposed afterwards.
        /// </summary>
        /// <param name="connectionString">Connection string to test; empty means <see cref="ConnectString"/>.</param>
        /// <param name="isThrow">When false, a failure to open the connection returns false instead of throwing.</param>
        public virtual bool Ping(string connectionString = "", bool isThrow = false)
        {
            if (connectionString.IsNullOrEmpty())
                connectionString = this.ConnectString;

            using (var connection = this.SqlDbProvider.DbProviderFactory.CreateConnection())
            {
                try
                {
                    connection.ConnectionString = connectionString;
                    if (connection.State != ConnectionState.Open)
                        connection.Open();
                }
                catch (Exception) when (!isThrow)
                {
                    return false;
                }

                return Helpers.DbConnectionExtensions.Ping(connection, this.SqlDbProvider.FullSqlDateNow, isThrow);
            }
        }

        #endregion

        #region Master/slave routing

        /// <summary>
        /// Rebuilds <see cref="SlaveConnectStrings"/> from configuration. Each slave entry may be the
        /// name of another configured connection or a raw connection string. Duplicates and slaves that
        /// fail to ping are skipped.
        /// </summary>
        protected void InitSlaveConnectStrings(ConnectionStringSetting connectionStringSetting)
        {
            if (connectionStringSetting.IsNullOrEmpty() || connectionStringSetting.SlaveConnectStrings.IsNullOrEmpty_())
                return;

            this.IsUseMasterSlaveSeparation = connectionStringSetting.IsUseMasterSlaveSeparation;

            if (SlaveConnectStrings == null) SlaveConnectStrings = new List<string>();
            SlaveConnectStrings.Clear();

            foreach (var slaveEntry in connectionStringSetting.SlaveConnectStrings)
            {
                if (slaveEntry.IsNullOrEmpty())
                    continue;

                try
                {
                    var slaveConnectString = ResolveSlaveConnectString(slaveEntry);
                    if (!SlaveConnectStrings.Any(existing => EqualsConnectionString(existing, slaveConnectString))
                        && Ping(slaveConnectString))
                    {
                        SlaveConnectStrings.Add(slaveConnectString);
                    }
                }
                catch
                {
                    // An unreachable or invalid slave is skipped; the master remains usable.
                }
            }
        }

        /// <summary>
        /// Returns the connection string of the configured entry named <paramref name="nameOrConnectString"/>,
        /// or the argument itself when no such entry exists (it is then treated as a raw connection string).
        /// </summary>
        private static string ResolveSlaveConnectString(string nameOrConnectString)
        {
            var connectionStrings = CNative.Utilities.ConfigurationHelper.ConnectionStrings;
            // Check the key first: indexing a missing key would throw and silently drop the slave.
            var setting = connectionStrings != null && connectionStrings.ContainsKey(nameOrConnectString)
                ? connectionStrings[nameOrConnectString]
                : null;

            return setting != null && setting.ConnectionString.IsNotNullOrEmpty()
                ? setting.ConnectionString
                : nameOrConnectString;
        }

        /// <summary>
        /// Chooses the connection string for the next new connection. Read-only text commands outside a
        /// transaction go to a random slave when separation is active; everything else goes to the master.
        /// </summary>
        protected virtual void SetSlaveConnectString(SqlEntity sql)
        {
            var routeToSlave = sql != null
                && sql.CommandType == CommandType.Text
                && this._transaction == null
                && this.IsMasterSlaveSeparation
                && IsRead(sql.Sql);

            if (!routeToSlave)
            {
                _ConnectString = _MasterConnectString;
                return;
            }

            var slaves = this.SlaveConnectStrings;
            _ConnectString = slaves[slaves.GetRandomIndex()];
        }

        /// <summary>
        /// Heuristically decides whether <paramref name="sql"/> only reads data: it must contain the
        /// SELECT keyword and none of the data/schema-modifying keywords. Matching is whole-word,
        /// case-insensitive and culture-invariant. Keywords inside string literals or comments are
        /// matched too, which can only send a read to the master.
        /// </summary>
        protected virtual bool IsRead(string sql)
        {
            if (string.IsNullOrEmpty(sql))
                return false;

            return ReadKeywordRegex.IsMatch(sql) && !WriteKeywordRegex.IsMatch(sql);
        }

        /// <summary>
        /// Compares two connection strings as sets of ';'-separated, trimmed, non-empty segments
        /// (order-insensitive, case-sensitive). Returns false when either is null or empty.
        /// </summary>
        protected virtual bool EqualsConnectionString(string connectionString1, string connectionString2)
        {
            if (connectionString1.IsNullOrEmpty() || connectionString2.IsNullOrEmpty())
                return false;

            // SetEquals compares both directions; a one-way Except would treat "a;b" as equal to "a;b;c".
            return SplitConnectionString(connectionString1).SetEquals(SplitConnectionString(connectionString2));
        }

        /// <summary>Splits a connection string into a set of trimmed, non-empty segments.</summary>
        private static HashSet<string> SplitConnectionString(string connectionString)
        {
            return new HashSet<string>(
                connectionString.Split(';').Select(segment => segment.Trim()).Where(segment => segment.Length > 0),
                StringComparer.Ordinal);
        }

        #endregion
    }
}
